import torch
import clip  # openai/CLIP
from PIL import Image
from .my_reward import RewardModel, calculate_batch_scores


def reward_fn(
    dtype,
    device,
    *,
    checkpoint_path="../reward_model/ckpts/RM_old_3_sd35_collision_10.pt",
    score_scale=10,
):
    """Build a reward function backed by CLIP ViT-L/14 features and a trained RewardModel.

    Args:
        dtype: Accepted for signature compatibility with the other reward
            factories in this module; not used here.
        device: Device the CLIP model, the reward model and the loaded
            checkpoint tensors are placed on.
        checkpoint_path: Path to the RewardModel checkpoint. The default is a
            relative path, so it is resolved against the current working
            directory.
        score_scale: Multiplier applied to the scores returned by
            ``calculate_batch_scores``.

    Returns:
        A callable ``_fn(images, prompts, metadata=None)`` returning
        ``(scores * score_scale, {})``; ``metadata`` is ignored.

    Raises:
        KeyError: if the checkpoint has no ``"model_state_dict"`` entry.
    """
    # Load the CLIP backbone and its matching image preprocessing transform.
    clip_model, preprocess = clip.load("ViT-L/14", device=device)
    # Scoring only: CLIP's weights are not trained here.
    clip_model.requires_grad_(False)

    # embed_dim must match the CLIP backbone's embedding width (768 for ViT-L/14).
    reward_model = RewardModel(embed_dim=768).to(device)

    # Load the trained reward-model weights.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = checkpoint["model_state_dict"]
    # Checkpoints saved from DataParallel/DDP-wrapped models prefix every key
    # with "module.". Strip only that leading prefix: str.replace would also
    # mangle keys that contain "module." elsewhere (e.g. "submodule.weight").
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    reward_model.load_state_dict(state_dict)
    # Switch to eval mode so any dropout/batch-norm layers behave
    # deterministically, and freeze the weights like the other scorers here.
    reward_model.eval()
    reward_model.requires_grad_(False)

    def _fn(images, prompts, metadata=None):
        # Per the original author: images should be a numpy array of images
        # with values in 0-255, shaped (B, H, W, C).
        scores = calculate_batch_scores(prompts, images, preprocess, clip_model, reward_model, device)
        # NOTE(review): assumes scores is a tensor/ndarray; if it were a
        # Python list, `* score_scale` would repeat it instead of scaling it.
        return scores * score_scale, {}

    return _fn



def aesthetic_score(dtype, device):
    """Build a reward function that scores images with ``AestheticScorer``.

    Args:
        dtype: Unused; the scorer is always built in float32.
        device: Device the scorer is placed on.

    Returns:
        A callable ``_fn(images, prompts, metadata=None)`` returning
        ``(scores, {})``. Only ``images`` are scored; ``prompts`` and
        ``metadata`` are ignored.
    """
    from .aesthetic_scorer import AestheticScorer

    scorer = AestheticScorer(dtype=torch.float32, device=device)
    # Frozen: the scorer's weights are not trained here.
    scorer.requires_grad_(False)

    def _fn(images, prompts, metadata=None):
        # The debug print of scores.size() was removed: it fired on every
        # reward call and polluted training logs.
        scores = scorer(images)
        return scores, {}

    return _fn


def ImageReward(dtype, device):
    """Create a reward callable that scores (image, prompt) pairs with ImageRewardScorer.

    ``dtype`` is not used; the scorer is always instantiated in float32 on
    ``device``. The returned ``_fn(images, prompts, metadata=None)`` yields
    ``(scores, {})`` and ignores ``metadata``.
    """
    from .ImageReward_scorer import ImageRewardScorer

    frozen_scorer = ImageRewardScorer(device=device, dtype=torch.float32)
    frozen_scorer.requires_grad_(False)  # weights are not trained here

    def _fn(images, prompts, metadata=None):
        return frozen_scorer(images, prompts), {}

    return _fn


def hpsv2(dtype, device):
    """Create a reward callable that scores (image, prompt) pairs with HPSv2Scorer.

    ``dtype`` is not used; the scorer is always instantiated in float32 on
    ``device``. The returned ``_fn(images, prompts, metadata=None)`` yields
    ``(scores, {})`` and ignores ``metadata``.
    """
    from .hpsv2_scorer import HPSv2Scorer

    frozen_scorer = HPSv2Scorer(device=device, dtype=torch.float32)
    frozen_scorer.requires_grad_(False)  # weights are not trained here

    def _fn(images, prompts, metadata=None):
        return frozen_scorer(images, prompts), {}

    return _fn


def PickScore(dtype, device):
    """Create a reward callable that scores (image, prompt) pairs with PickScoreScorer.

    ``dtype`` is not used; the scorer is always instantiated in float32 on
    ``device``. The returned ``_fn(images, prompts, metadata=None)`` yields
    ``(scores, {})`` and ignores ``metadata``.
    """
    from .PickScore_scorer import PickScoreScorer

    frozen_scorer = PickScoreScorer(device=device, dtype=torch.float32)
    frozen_scorer.requires_grad_(False)  # weights are not trained here

    def _fn(images, prompts, metadata=None):
        return frozen_scorer(images, prompts), {}

    return _fn
